#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2018 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

from __future__ import print_function
from __future__ import unicode_literals

import argparse
import codecs
import json
import logging
import sys

from distutils.util import strtobool

from espnet.utils.cli_utils import get_commandline_args

is_python2 = sys.version_info[0] == 2


def get_parser():
    """Build the command-line parser for this script.

    Positional ``jsons`` are the json files to combine; ``-i/--is-input``
    (parsed with ``strtobool``) chooses whether the extra values are added
    to the input or the output side; ``-V/--verbose`` is an integer level.

    :return: the configured parser
    :rtype: argparse.ArgumentParser
    """
    arg_specs = (
        (('jsons',),
         dict(type=str, nargs='+', help='json files')),
        (('-i', '--is-input'),
         dict(default=True, type=strtobool,
              help='If true, add to input. If false, add to output')),
        (('--verbose', '-V'),
         dict(default=0, type=int, help='Verbose option')),
    )
    cli = argparse.ArgumentParser(
        description='add multiple json values to an input or output value',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # register in the same order as listed so help output is stable
    for flags, options in arg_specs:
        cli.add_argument(*flags, **options)
    return cli


def update_intersection(common, keys):
    """Intersect a new collection of utterance ids into a running key set.

    :param common: keys shared by every json processed so far, or None when
        no json has been processed yet
    :type common: set or None
    :param keys: utterance ids of the json being added (any iterable)
    :return: new set of keys shared by all jsons processed so far
    :rtype: set
    """
    keys = set(keys)
    # None (rather than an empty set) marks "nothing read yet", so an empty
    # first json correctly produces an empty intersection instead of being
    # replaced by the next file's keys.
    if common is None:
        return keys
    return common & keys


def make_added_entry(adddic, dim_key, len_key, name):
    """Build the stream entry appended to an utterance's input/output list.

    :param dict adddic: merged per-utterance values from the additional jsons
    :param str dim_key: key holding the dimension ('idim' or 'odim')
    :param str len_key: key holding the length ('ilen' or 'olen')
    :param str name: value stored under 'name' (e.g. 'input2', 'target2')
    :return: new dict with 'shape' set to [int(len), int(dim)] when both keys
        exist, [int(dim)] when only the dimension exists (omitted otherwise),
        every other key of ``adddic`` copied as-is, and 'name' set last
    :rtype: dict
    """
    entry = {}
    if dim_key in adddic and len_key in adddic:
        entry['shape'] = [int(adddic[len_key]), int(adddic[dim_key])]
    elif dim_key in adddic:
        entry['shape'] = [int(adddic[dim_key])]
    # copy all remaining values; dim/len have been folded into 'shape' above
    for key, value in adddic.items():
        if key not in (dim_key, len_key):
            entry[key] = value
    entry['name'] = name
    return entry


def merge_utts(js, keys, is_input):
    """Append the values of js[1:] as one new stream to the utterances of js[0].

    :param list js: loaded json dicts, each with an 'utts' mapping. js[0] is
        the original json; the per-utterance dicts of js[1:] are merged in
        order (later files overwrite earlier ones on key clashes)
    :param keys: utterance ids present in every json
    :param bool is_input: if true, append to 'input' as 'input<N>';
        otherwise append to 'output' as 'target<N>'
    :return: new 'utts' mapping whose entries hold 'input', 'output',
        'utt2spk' ('' when the original lacks it) and, when the original has
        it, 'lang'
    :rtype: dict
    :raises ValueError: if fewer than two jsons are given
    """
    if len(js) < 2:
        raise ValueError('at least two jsons are required, got %d' % len(js))
    new_dic = {}
    for key_id in keys:
        orgdic = js[0]['utts'][key_id]
        # copy so that merging js[2:] does not mutate the loaded js[1]
        adddic = dict(js[1]['utts'][key_id])
        for j in js[2:]:
            adddic.update(j['utts'][key_id])

        if is_input:
            # copy the list so the loaded original json is left untouched
            input_list = list(orgdic['input'])
            name = 'input%d' % (len(input_list) + 1)
            input_list.append(make_added_entry(adddic, 'idim', 'ilen', name))
            output_list = orgdic['output']
        else:
            output_list = list(orgdic['output'])
            name = 'target%d' % (len(output_list) + 1)
            output_list.append(make_added_entry(adddic, 'odim', 'olen', name))
            input_list = orgdic['input']

        entry = {'input': input_list,
                 'output': output_list,
                 'utt2spk': orgdic.get('utt2spk', '')}
        # NOTE: for machine translation -- carry over the original 'lang'
        # value (presumably a language id) in both input and output modes
        if 'lang' in orgdic:
            entry['lang'] = orgdic['lang']
        new_dic[key_id] = entry
    return new_dic


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    # js[0] is the original json; at least one more is needed to add from
    if len(args.jsons) < 2:
        parser.error('at least two json files are required')

    # logging info
    logfmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    # load the jsons and keep only utterances shared by all of them
    js = []
    intersec_ks = None
    for x in args.jsons:
        with codecs.open(x, 'r', encoding="utf-8") as f:
            j = json.load(f)
        js.append(j)
        ks = j['utts'].keys()
        logging.info(x + ': has ' + str(len(ks)) + ' utterances')
        intersec_ks = update_intersection(intersec_ks, ks)
        if not intersec_ks:
            # an empty set cannot grow again; skip the remaining files
            logging.warning("Empty intersection")
            break
    logging.info('new json has ' + str(len(intersec_ks)) + ' utterances')

    # with an empty intersection there is nothing to merge (and js may hold
    # only one file because of the early break above)
    new_dic = merge_utts(js, intersec_ks, args.is_input) if intersec_ks else {}

    # ensure_ascii=False keeps non-ASCII text unescaped; stdout is wrapped in
    # a UTF-8 writer so printing it does not depend on the terminal locale
    jsonstring = json.dumps({'utts': new_dic}, indent=4, ensure_ascii=False,
                            sort_keys=True, separators=(',', ': '))
    sys.stdout = codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
    print(jsonstring)
